# -*- coding: utf-8 -*-
"""
    @Author ljy
    @Date 2020/11/4 17:26
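    @Describe MySQL access helpers: DBConn (transactional cursor context manager), Dao (simple CRUD), and a shared Redis client.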
"""

from pymysql import Connect
from pymysql.cursors import DictCursor
from redis import Redis

# Module-level Redis client (host 10.36.174.47, port 6378), created at import
# time; decode_responses=True asks redis-py to return str instead of bytes.
# NOTE(review): not referenced by the code in this file -- presumably used by
# modules that import it. Port 6378 differs from the usual 6379; confirm it is
# intentional. Host/port are hard-coded; consider moving them to configuration.
_rds = Redis('10.36.174.47', port=6378, decode_responses=True)

class DBConn():
    """Transactional MySQL connection usable as a ``with`` context manager.

    ``with conn as cursor:`` opens a fresh cursor on the shared connection,
    commits when the block finishes normally, and rolls back (after printing
    the error) when it raises. The cursor is closed on exit in both cases.

    NOTE: exceptions raised inside the ``with`` block are deliberately
    *suppressed* (``__exit__`` returns True). ``Dao`` relies on this to fall
    through to its "failure" return values instead of raising.

    Connection parameters default to the values that used to be hard-coded
    here, so ``DBConn()`` connects exactly as before; any extra keyword
    arguments are forwarded unchanged to ``pymysql.Connect``.
    """

    def __init__(self, host='10.36.174.47', port=3306, user='root',
                 password='root', db='apidb', charset='utf8', **kwargs):
        # Rows come back as dicts (DictCursor) unless the caller overrides it.
        kwargs.setdefault('cursorclass', DictCursor)
        self.conn = Connect(host=host,
                            port=port,
                            user=user,
                            password=password,
                            db=db,
                            charset=charset,
                            **kwargs)
        # Cursors opened by __enter__, most recent last, so each __exit__
        # closes the cursor that belongs to its own ``with`` block.
        self._cursors = []

    def __enter__(self):
        """Open a new cursor on the connection and return it."""
        cursor = self.conn.cursor()
        self._cursors.append(cursor)
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit on success, roll back on error, always close the cursor.

        :return: True, which suppresses any exception raised in the block.
        """
        cursor = self._cursors.pop()
        try:
            if exc_type:
                print('----->操作异常', exc_type, exc_val, exc_tb)
                self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            # Previously the cursor was never closed, leaking one per ``with``.
            cursor.close()
        # Deliberately swallow the exception; callers (Dao) return a failure
        # value after the ``with`` block instead.
        return True

    def close(self):
        """Close the connection, ignoring errors (e.g. already closed)."""
        try:
            self.conn.close()
        except Exception:
            # Best effort: this runs from Dao.__del__, where raising is useless.
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate.
            pass

class Dao():
    """Minimal CRUD helper over a single ``DBConn``.

    Each method runs its statement(s) inside ``with self.db as c``, so each
    call is committed or rolled back on its own. ``DBConn.__exit__``
    suppresses exceptions, so failed statements do not raise here. Instead
    each method returns a "failure" value (0, False, None or an empty list),
    as documented per method.

    NOTE(review): ``table``, column names and ``where`` clauses are
    interpolated directly into the SQL text (only values are bound as
    parameters), so they must never come from untrusted input.
    """

    def __init__(self):
        self.db = DBConn()

    def __del__(self):
        # ``self.db`` does not exist if DBConn() raised inside __init__.
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()

    def save(self, table, pk='id', **item) -> int:
        """Insert ``item`` as one row of ``table`` and return its primary key.

        :param table: table name (interpolated into the SQL, not escaped).
        :param pk: primary-key column, used only by the fallback lookup.
        :param item: column=value pairs; values are bound as parameters.
        :return: the generated id (``cursor.lastrowid``). If the driver
            reports none (e.g. no AUTO_INCREMENT column), returns
            ``max(pk)`` of the table as before. Returns 0 when no row was
            inserted or the statement failed.
        """
        fields = ','.join(item)
        values = ','.join('%%(%s)s' % k for k in item)
        sql = 'insert into %s (%s) values(%s)' % (table, fields, values)
        with self.db as c:
            if c.execute(sql, args=item):
                # lastrowid is tied to this connection's own insert, unlike
                # ``select max(pk)``, which can return a row inserted
                # concurrently by another client.
                if c.lastrowid:
                    return c.lastrowid
                c.execute(f'select max({pk}) as next_id from {table}')
                ret = c.fetchone()  # e.g. {'next_id': 2}
                return ret['next_id']
        return 0

    def update(self, table, pk='id', extra_where_field=None, **item) -> bool:
        """Update the row whose ``pk`` equals ``item[pk]``.

        :param table: table name (interpolated into the SQL, not escaped).
        :param pk: primary-key column; its value must be present in ``item``.
        :param extra_where_field: optional extra column that must also equal
            its value in ``item`` for the row to match.
        :param item: column=value pairs; every key except ``pk`` is SET.
        :return: True if at least one row was reported as affected,
            False otherwise (including when the statement failed).
        """
        update_fields = ','.join('%s=%%(%s)s' % (k, k) for k in item if k != pk)
        where = f'{pk}=%({pk})s'
        if extra_where_field:
            where += f' and {extra_where_field}=%({extra_where_field})s'
        sql = 'update %s set %s where %s' % (table, update_fields, where)
        with self.db as c:
            c.execute(sql, args=item)
            return c.rowcount > 0
        # Reached only when DBConn swallowed an error from the statement.
        return False

    def query(self, table, *field, where=None, args=None, one=False):
        """Select ``field`` columns (all columns if none given) from ``table``.

        :param where: optional raw WHERE clause, may contain placeholders.
        :param args: parameters bound to the placeholders in ``where``.
        :param one: return only the first row (or None) instead of all rows.
        :return: the row/rows as produced by the cursor; None (``one``) or an
            empty list when the statement failed.
        """
        columns = ','.join(field) if field else '*'
        sql = 'select %s from %s' % (columns, table)
        if where:
            sql += ' where ' + where
        # Default result if the statement fails and DBConn swallows the error;
        # previously ``ret`` stayed unbound and raised UnboundLocalError.
        ret = None if one else []
        with self.db as c:
            c.execute(sql, args=args)
            ret = c.fetchone() if one else c.fetchall()
        return ret

    def delete(self, table, where, args=None):
        """Delete rows of ``table`` matching the raw ``where`` clause.

        :return: number of affected rows, or 0 when the statement failed.
        """
        with self.db as c:
            # execute() returns the affected-row count for insert/update/delete.
            return c.execute('delete from %s where %s' % (table, where), args=args)
        # Reached only when DBConn swallowed an error from the statement.
        return 0


